#!/usr/bin/env python3 -u
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
"""
Train a new model on one or across multiple GPUs.
"""

import argparse
import logging
import logging.config
import math
import os
import pprint
import random
import sys
from typing import Any, Callable, Dict, List, Optional, Tuple

# We need to set up the root logger before importing any fairseq libraries.
logging.basicConfig(
    format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
    datefmt="%Y-%m-%d %H:%M:%S",
    level=os.environ.get("LOGLEVEL", "INFO").upper(),
    stream=sys.stdout,
)
logger = logging.getLogger("fairseq_cli.train")

import numpy as np
import torch
from fairseq_signals import tasks
from fairseq_signals.utils import utils, checkpoint_utils, options
from fairseq_signals.data import iterators, data_utils
from fairseq_signals.dataclass.configs import Config
from fairseq_signals.dataclass.utils import convert_namespace_to_omegaconf
from fairseq_signals.distributed import utils as distributed_utils
from fairseq_signals.utils.file_io import PathManager
from fairseq_signals.logging import meters, metrics, progress_bar
from fairseq_signals.trainer import Trainer
from omegaconf import DictConfig, OmegaConf


def main(cfg: Config) -> None:
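    """Set up the task, model, criterion, and trainer from `cfg`, then run the training loop."""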
    if isinstance(cfg, argparse.Namespace):
        cfg = convert_namespace_to_omegaconf(cfg)
    
    utils.import_user_module(cfg.common)

    if distributed_utils.is_master(cfg.distributed_training) and "job_logging_cfg" in cfg:
        # make hydra logging work with ddp (see https://github.com/facebookresearch/hydra/issues/1126)
        logging.config.dictConfig(OmegaConf.to_container(cfg.job_logging_cfg))

    assert (
        cfg.dataset.max_tokens is not None or cfg.dataset.batch_size is not None
    ), "Must specify batch size either with --max-tokens or --batch-size"
    metrics.reset()

    if cfg.common.log_file is not None:
        handler = logging.FileHandler(filename=cfg.common.log_file)
        logger.addHandler(handler)
    
    np.random.seed(cfg.common.seed)
    random.seed(cfg.common.seed)
    utils.set_torch_seed(cfg.common.seed)

    if distributed_utils.is_master(cfg.distributed_training):
        checkpoint_utils.verify_checkpoint_directory(cfg.checkpoint.save_dir)
    
    # Print args
    logger.info(pprint.pformat(dict(cfg)))

    # if cfg.checkpoint.write_checkpoints_asynchronously:
    #     try:
    #         import iopath  # noqa: F401
    #     except ImportError:
    #         logging.exception(
    #             "Asynchronous checkpoint writing is specified but iopath is "
    #             "not installed: `pip install iopath`"
    #         )
    #         return

    # Set up the task, e.g., ecg_pretraining
    task = tasks.setup_task(cfg.task)

    assert cfg.criterion, "Please specify criterion to train a model"

    # Build model and criterion
    model = task.build_model(cfg.model)
    criterion = task.build_criterion(cfg.criterion)

    logger.info(model)
    logger.info("task: {}".format(task.__class__.__name__))
    logger.info("model: {}".format(model.__class__.__name__))
    logger.info("criterion: {}".format(criterion.__class__.__name__))
    logger.info(
        "num. shared model params: {:,} (num. trained: {:,})".format(
            sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False)),
            sum(p.numel() for p in model.parameters() if not getattr(p, "expert", False) and p.requires_grad)
        )
    )

    logger.info(
        "num. expert model params: {} (num. trained: {})".format(
            sum(p.numel() for p in model.parameters() if getattr(p, "expert", False)),
            sum(p.numel() for p in model.parameters() if getattr(p, "expert", False) and p.requires_grad)
        )
    )

    # Load valid dataset (we load training data below, based on the latest checkpoint)
    # We load the valid dataset AFTER building the model
    if not cfg.dataset.disable_validation:
        data_utils.raise_if_valid_subsets_unintentionally_ignored(cfg)
        if cfg.dataset.combine_valid_subsets:
            task.load_dataset("valid", combine = True, epoch = 1)
        else:
            for valid_sub_split in cfg.dataset.valid_subset.split(","):
                task.load_dataset(valid_sub_split, combine = False, epoch = 1)

    # Build trainer
    trainer = Trainer(cfg, task, model, criterion)

    logger.info(
        "training on {} devices (GPUs)".format(
            cfg.distributed_training.distributed_world_size
        )
    )
    logger.info(
        "max tokens per device = {} and signals per device = {}".format(
            cfg.dataset.max_tokens,
            cfg.dataset.batch_size
        )
    )

    # Load the latest checkpoint if one is available and restore the
    # corresponding train iterator
    extra_state, epoch_itr = checkpoint_utils.load_checkpoint(
        cfg.checkpoint,
        trainer,
        # don't cache epoch iterators for sharded datasets
        disable_iterator_cache=task.has_sharded_data("train"),
    )

    max_epoch = cfg.optimization.max_epoch or math.inf
    lr = trainer.get_lr()

    train_meter = meters.StopwatchMeter()
    train_meter.start()
    while epoch_itr.next_epoch_idx <= max_epoch:
        if lr <= cfg.optimization.stop_min_lr:
            logger.info(
                f"stopping training because current learning rate ({lr}) is smaller "
                "than or equal to minimum learning rate "
                f"(--stop-min-lr={cfg.optimization.stop_min_lr}"
            )
            break

        # train for one epoch
        valid_losses, should_stop = train(cfg, trainer, task, epoch_itr)
        if should_stop:
            break
        
        # only use first validation loss to update the learning rate
        lr = trainer.lr_step(epoch_itr.epoch, valid_losses[0])

        epoch_itr = trainer.get_train_iterator(
            epoch_itr.next_epoch_idx,
            # sharded data: get train iterator for next epoch
            load_dataset=task.has_sharded_data("train"),
            # don't cache epoch iterators for sharded datasets
            disable_iterator_cache=task.has_sharded_data("train"),
        )
    train_meter.stop()
    logger.info("done training in {:.1f} seconds".format(train_meter.sum))

    # ioPath implementation to wait for all asynchronous file writes to complete.
    # if cfg.checkpoint.write_checkpoints_asynchronously:
    #     logger.info(
    #         "ioPath PathManager waiting for all asynchronous checkpoint "
    #         "writes to finish."
    #     )
    #     PathManager.async_close()
    #     logger.info("ioPath PathManager finished waiting.")

def should_stop_early(cfg: DictConfig, valid_loss: float) -> bool:
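    """Return True when the validation metric has not improved for `cfg.checkpoint.patience` consecutive validations.

    The best value and the miss counter are stored as attributes on this function,
    so the early-stopping state persists across calls within a single run.
    """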
    # skip check if no validation was done in the current epoch
    if valid_loss is None:
        return False
    if cfg.checkpoint.patience <= 0:
        return False
    
    def is_better(a, b):
        return a > b if cfg.checkpoint.maximize_best_checkpoint_metric else a < b
    
    prev_best = getattr(should_stop_early, "best", None)
    if prev_best is None or is_better(valid_loss, prev_best):
        should_stop_early.best = valid_loss
        should_stop_early.num_runs = 0
        return False
    else:
        should_stop_early.num_runs += 1
        if should_stop_early.num_runs >= cfg.checkpoint.patience:
            logger.info(
                "early stop since valid performance hasn't improved for last {} runs".format(
                    cfg.checkpoint.patience
                )
            )
            return True
        else:
            return False

@metrics.aggregate("train")
def train(
    cfg: DictConfig, trainer: Trainer, task: tasks.Task, epoch_itr
) -> Tuple[List[Optional[float]], bool]:
    """Train the model for one epoch and return validation losses."""
    # Initialize data iterator
    itr = epoch_itr.next_epoch_itr(
        fix_batches_to_gpus=cfg.distributed_training.fix_batches_to_gpus,
        shuffle=(epoch_itr.next_epoch_idx > cfg.dataset.curriculum),
    )
    update_freq = (
        cfg.optimization.update_freq[epoch_itr.epoch - 1]
        if epoch_itr.epoch <= len(cfg.optimization.update_freq)
        else cfg.optimization.update_freq[-1]
    )
    itr = iterators.GroupedIterator(itr, update_freq)
    progress = progress_bar.progress_bar(
        itr,
        log_format=cfg.common.log_format,
        log_file=cfg.common.log_file,
        log_interval=cfg.common.log_interval,
        epoch=epoch_itr.epoch,
        tensorboard_logdir=None,
        default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
        wandb_project=(
            cfg.common.wandb_project
            if distributed_utils.is_master(cfg.distributed_training)
            else None
        ),
        wandb_entity=(
            cfg.common.wandb_entity
            if distributed_utils.is_master(cfg.distributed_training)
            else None
        ),
        wandb_run_name=os.environ.get(
            "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
        ),
        azureml_logging=False,
    )
    progress.update_config(_flatten_config(cfg))

    trainer.begin_epoch(epoch_itr.epoch)

    valid_subsets = [x.strip() for x in cfg.dataset.valid_subset.split(",")]
    should_stop = False
    valid_losses = [None]  # fallback in case the epoch iterator yields no batches
    num_updates = trainer.get_num_updates()
    logger.info("Start iterating over samples")
    for i, samples in enumerate(progress):
        with metrics.aggregate("train_inner"), torch.autograd.profiler.record_function(
            "train_step-%d" % i
        ):
            log_output = trainer.train_step(samples)

        if log_output is not None: # not OOM, overflow, ...
            # log mid-epoch stats
            num_updates = trainer.get_num_updates()
            if num_updates % cfg.common.log_interval == 0:
                stats = get_training_stats(metrics.get_smoothed_values("train_inner"))
                progress.log(stats, tag="train_inner", step=num_updates)

                # reset mid-epoch stats after each log interval
                # the end-of-epoch stats will still be preserved
                metrics.reset_meters("train_inner")

        #NOTE utils for finding unused parameters
        # for n, p in trainer.model.named_parameters():
        #     if p.requires_grad and p.grad is None:
        #         print(n)
        # breakpoint()

        end_of_epoch = not itr.has_next()
        # NOTE hack for end epoch after first step
        # end_of_epoch = True
        valid_losses, should_stop = validate_and_save(
            cfg, trainer, task, epoch_itr, valid_subsets, end_of_epoch
        )

        if should_stop:
            break
    
    # log end-of-epoch stats
    logger.info("end of epoch {} (average epoch stats below)".format(epoch_itr.epoch))
    stats = get_training_stats(metrics.get_smoothed_values("train"))
    progress.print(stats, tag="train", step=num_updates)

    # reset epoch-level meters
    metrics.reset_meters("train")
    return valid_losses, should_stop

def _flatten_config(cfg: DictConfig):
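    """Convert `cfg` to a plain container, collapsing any legacy argparse.Namespace values into a single "args" entry."""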
    config = OmegaConf.to_container(cfg)
    # remove any legacy Namespaces and replace with a single "args"
    namespace = None
    for k, v in list(config.items()):
        if isinstance(v, argparse.Namespace):
            namespace = v
            del config[k]
    if namespace is not None:
        config["args"] = vars(namespace)
    return config

def validate_and_save(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.Task,
    epoch_itr,
    valid_subsets: List[str],
    end_of_epoch: bool
) -> Tuple[List[Optional[float]], bool]:
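    """Run validation and/or checkpointing as dictated by the config and the current step; return (valid_losses, should_stop)."""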
    num_updates = trainer.get_num_updates()
    max_update = cfg.optimization.max_update or math.inf

    # Stopping conditions (an additional one, based on validation loss, is applied later on)
    should_stop = False
    if num_updates >= max_update:
        should_stop = True
        logger.info(
            f"Stopping training due to "
            f"num_updates: {num_updates} >= max_update: {max_update}"
        )
    
    training_time_hours = trainer.cumulative_training_time() / (60 * 60)
    if (
        cfg.optimization.stop_time_hours > 0
        and training_time_hours > cfg.optimization.stop_time_hours
    ):
        should_stop = True
        logger.info(
            f"Stopping training due to "
            f"cumulative_training_time: {training_time_hours} > "
            f"stop_time_hours: {cfg.optimization.stop_time_hours} hour(s)"
        )
    
    do_save = (
        (end_of_epoch and epoch_itr.epoch % cfg.checkpoint.save_interval == 0)
        or should_stop
        or (
            cfg.checkpoint.save_interval_updates > 0
            and num_updates > 0
            and num_updates % cfg.checkpoint.save_interval_updates == 0
            and num_updates >= cfg.dataset.validate_after_updates
        )
    )
    do_validate = (
        (
            (not end_of_epoch and do_save)  # validate during mid-epoch saves
            or (end_of_epoch and epoch_itr.epoch % cfg.dataset.validate_interval == 0)
            or should_stop
            or (
                cfg.dataset.validate_interval_updates > 0
                and num_updates > 0
                and num_updates % cfg.dataset.validate_interval_updates == 0
            )
        )
        and not cfg.dataset.disable_validation
        and num_updates >= cfg.dataset.validate_after_updates
    )

    # Validate
    valid_losses = [None]
    if do_validate:
        valid_losses = validate(cfg, trainer, task, epoch_itr, valid_subsets)
    
    should_stop |= should_stop_early(cfg, valid_losses[0])

    # save checkpoint
    if do_save or should_stop:
        checkpoint_utils.save_checkpoint(
            cfg.checkpoint, trainer, epoch_itr, valid_losses[0]
        )
        if torch.distributed.is_initialized():
            torch.distributed.barrier()

    return valid_losses, should_stop

def get_training_stats(stats: Dict[str, Any]) -> Dict[str, Any]:
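    """Add the elapsed wall-clock time to the aggregated training stats."""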
    stats["wall"] = round(metrics.get_meter("default", "wall").elapsed_time, 0)
    return stats

def validate(
    cfg: DictConfig,
    trainer: Trainer,
    task: tasks.Task,
    epoch_itr,
    subsets: List[str]
) -> List[Optional[float]]:
    """Evaluate the model on the validation set(s) and return the losses"""

    if cfg.dataset.fixed_validation_seed is not None:
        # set fixed seed for every validation
        utils.set_torch_seed(cfg.dataset.fixed_validation_seed)
    
    trainer.begin_valid_epoch(epoch_itr.epoch)
    valid_losses = []
    for subset in subsets:
        logger.info('begin validation on "{}" subset'.format(subset))

        # Initialize data iterator
        itr = trainer.get_valid_iterator(subset).next_epoch_itr(
            shuffle=False, set_dataset_epoch=False  # use a fixed valid set
        )
        progress = progress_bar.progress_bar(
            itr,
            log_format=cfg.common.log_format,
            log_interval=cfg.common.log_interval,
            epoch=epoch_itr.epoch,
            prefix=f"valid on '{subset}' subset",
            tensorboard_logdir=None,
            default_log_format=("tqdm" if not cfg.common.no_progress_bar else "simple"),
            wandb_project=(
                cfg.common.wandb_project
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            wandb_entity=(
                cfg.common.wandb_entity
                if distributed_utils.is_master(cfg.distributed_training)
                else None
            ),
            wandb_run_name=os.environ.get(
                "WANDB_NAME", os.path.basename(cfg.checkpoint.save_dir)
            ),
        )

        # create a new root metrics aggregator so validation metrics
        # don't pollute other aggregators (e.g., train meters)
        with metrics.aggregate(new_root=True) as agg:
            for i, sample in enumerate(progress):
                if cfg.dataset.max_valid_steps is not None and i > cfg.dataset.max_valid_steps:
                    break
                trainer.valid_step(sample, subset=subset)

        # log validation stats
        # stats = get_valid_stats(cfg, trainer, agg.get_smoothed_values())
        #XXX log validation stats for each validation set
        stats = get_valid_stats(cfg, trainer, subset, agg.get_smoothed_values())

        if hasattr(task, "post_validate"):
            task.post_validate(
                model=trainer.get_model(),
                log_output=stats,
                agg=agg,
                num_updates=trainer.get_num_updates()
            )

        progress.print(stats, tag=subset, step=trainer.get_num_updates())

        valid_losses.append(stats[cfg.checkpoint.best_checkpoint_metric])
    return valid_losses

#XXX logging best valid stats for each validation set
def get_valid_stats(
    cfg: DictConfig, trainer: Trainer, subset: str, stats: Dict[str, Any]
) -> Dict[str, Any]:
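    """Attach the update count and the best value of the checkpoint metric seen so far (tracked per subset) to `stats`."""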
# def get_valid_stats(
#     cfg: DictConfig, trainer: Trainer, stats: Dict[str, Any]
# ) -> Dict[str, Any]:
    stats["num_updates"] = trainer.get_num_updates()

    #XXX logging best valid stats for each validation set
    if not hasattr(get_valid_stats, "best"):
        get_valid_stats.best = dict()

    prev_best = getattr(get_valid_stats, "best").get(
        subset, stats[cfg.checkpoint.best_checkpoint_metric]
    )
    best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min
    get_valid_stats.best[subset] = best_function(
        stats[cfg.checkpoint.best_checkpoint_metric], prev_best
    )

    key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric)
    stats[key] = get_valid_stats.best[subset]

    # logging best stats according to the best validation step
    # if hasattr(checkpoint_utils.save_checkpoint, "best"):
    #     key = "best_{0}".format(cfg.checkpoint.best_checkpoint_metric)
    #     best_function = max if cfg.checkpoint.maximize_best_checkpoint_metric else min
    #     stats[key] = best_function(
    #         checkpoint_utils.save_checkpoint.best,
    #         stats[cfg.checkpoint.best_checkpoint_metric],
    #     )
    return stats

def cli_main(
    modify_parser: Optional[Callable[[argparse.ArgumentParser], None]] = None,
) -> None:
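    """Parse command-line arguments and dispatch `main` through distributed_utils.call_main."""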
    parser = options.get_training_parser()
    args = options.parse_args_and_arch(parser, modify_parser=modify_parser)

    cfg = convert_namespace_to_omegaconf(args)

    if args.profile:
        with torch.cuda.profiler.profile():
            with torch.autograd.profiler.emit_nvtx():
                distributed_utils.call_main(cfg, main)
    else:
        distributed_utils.call_main(cfg, main)
    

if __name__ == "__main__":
    cli_main()